import sys
import os

# Add parent folder to sys.path
parent_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_folder)

import torch
import torch.nn as nn
from Record.file_management import load_from_pickle, read_obj_dumps, strip_instance
import argparse
from Vae.data_utils.data_util import _CustomDataParallel
from models.cnn_vae import CNNVAE
import numpy as np
import PIL.Image as Image

from torch.utils.data import DataLoader
import torch.optim as optim

from data_utils.vae_encoding_dataset import ImageEncodingDataset
import glob
from tqdm import tqdm
import time
import wandb
import random


def _forward_encoding(model, encoding):
    """Run ``model`` on ``encoding`` and return a single prediction tensor.

    If ``model`` is a list it is unpacked as exactly two modules whose outputs
    are concatenated along dim 1 (first module's columns first); otherwise
    ``model`` is called directly.
    """
    if isinstance(model, list):
        state_model, vel_model = model
        return torch.cat([state_model(encoding), vel_model(encoding)], dim=1)
    return model(encoding)


def _set_training_mode(model, training):
    """Put a single module, or every module of a list, in train/eval mode."""
    modules = model if isinstance(model, list) else [model]
    for module in modules:
        module.train(training)


def train_encoding_linear(args, model, optimizer, device):
    """Train ``model`` to regress targets from encodings, with periodic validation.

    The loss is the sum of two MSE terms: one over output/target columns
    ``[:2]`` (logged as "S", presumably state/position) and one over columns
    ``[2:4]`` (logged as "V", presumably velocity).

    NOTE: batches are read from the module-level ``train_loader`` and
    ``val_loader`` globals, which must exist before this function is called.

    Args:
        args: Namespace providing ``epochs`` (number of epochs) and
            ``val_freq`` (validate every ``val_freq`` epochs).
        model: Either a single module, or a list ``[state_model, vel_model]``
            whose outputs are concatenated along dim 1.
        optimizer: Optimizer over the parameters of ``model``.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        None. Progress is printed and metrics are sent through ``wandb.log``.
    """
    mse = nn.MSELoss()
    _set_training_mode(model, True)

    for epoch in range(args.epochs):
        for batch_idx, (encoding, target) in enumerate(train_loader):
            encoding, target = encoding.to(device), target.to(device)
            optimizer.zero_grad()

            output = _forward_encoding(model, encoding)
            loss_s = mse(output[:, :2], target[:, :2])
            loss_v = mse(output[:, 2:4], target[:, 2:4])
            loss = loss_s + loss_v
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print("Epoch: {} | Batch: {} | S Loss: {}, V Loss: {}".format(epoch, batch_idx, loss_s.item(), loss_v.item()))
                wandb.log({"Train S Loss": loss_s.item(),
                           "Train V Loss": loss_v.item(),
                           "Train Loss": loss.item()})

        if epoch % args.val_freq != 0:
            continue

        val_s_loss, val_v_loss, num_val_batches = 0.0, 0.0, 0
        # Validation must run in eval mode: layers such as BatchNorm otherwise
        # normalise with per-batch statistics and update their running
        # estimates from validation data.
        _set_training_mode(model, False)
        try:
            with torch.no_grad():
                for encoding, target in val_loader:
                    encoding, target = encoding.to(device), target.to(device)
                    output = _forward_encoding(model, encoding)
                    val_s_loss += mse(output[:, :2], target[:, :2]).item()
                    val_v_loss += mse(output[:, 2:4], target[:, 2:4]).item()
                    num_val_batches += 1
        finally:
            # Always return to train mode, even if validation raised.
            _set_training_mode(model, True)

        if num_val_batches == 0:
            # Nothing to average; skip logging instead of dividing by zero.
            print("Epoch: {} | Validation skipped: no validation batches".format(epoch))
            continue

        # Losses are averaged per batch (not per sample), as before.
        val_s_loss /= num_val_batches
        val_v_loss /= num_val_batches
        print("Epoch: {} | Validation S Loss: {}, Validation V Loss: {}".format(epoch, val_s_loss, val_v_loss))
        wandb.log({
            "Epoch": epoch,
            "Val S Loss": val_s_loss,
            "Val V Loss": val_v_loss,
            "Val Loss": val_s_loss + val_v_loss})
    





if __name__ == "__main__":
    # Script entry point: parse CLI options, set hyperparameters, load the
    # precomputed encodings and object data, build two regression heads and
    # train them jointly.
    parser = argparse.ArgumentParser(description='placeholder')
    parser.add_argument('--load_rollouts', default="",
                        help='where to load the rollouts from')
    parser.add_argument('--checkpoint_path', default="",
                        help='where to load the model from')
    parser.add_argument('--save_rollouts', default="",
                        help='where to save the rollouts to')
    args = parser.parse_args()
    # Honour --load_rollouts when given; previously it was silently
    # overwritten. With no flag the paths resolve exactly as before.
    if not args.load_rollouts:
        args.load_rollouts = "datasets/box2d_default"
    # NOTE(review): assumes the encodings pickle lives inside the rollouts
    # directory, as it does for the default dataset -- confirm for others.
    args.load_encodings = os.path.join(args.load_rollouts, "encodings_stack5_vae_z128.pkl")

    # Hyperparameters
    args.latent_dim = 128
    args.batch_size = 2048
    args.lr = 1e-2
    args.epochs = 1000
    args.val_freq = 1
    # Prefer the GPU, but do not crash on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create a new directory for saving logs
    # (the directory is created here; nothing in this block writes into it).
    current_time = time.strftime("%Y%m%d-%H%M%S")
    run_name = current_time + '_' + "stack5_vae_z128_encoding_linear_2lyr512dimbn_separate_models"
    log_dir = os.path.join('data', run_name)
    os.makedirs(log_dir, exist_ok=True)
    # Initialize W&B project; hyperparameters set above are captured in config.
    wandb.init(project="causal_vae", entity="carltheq", config=args)
    wandb.run.name = run_name

    # Dataset
    obj_data = read_obj_dumps(args.load_rollouts, i=0, rng=-1, filename='object_dumps.txt')
    encodings = load_from_pickle(args.load_encodings)
    print("Loaded encodings")

    # Setup random seed (before model construction and data shuffling)
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Load Data
    # train_loader / val_loader are module-level names on purpose:
    # train_encoding_linear reads them as globals.
    train_loader = DataLoader(ImageEncodingDataset(encodings, obj_data, split='train'), batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(ImageEncodingDataset(encodings, obj_data, split='val'), batch_size=args.batch_size, shuffle=False, num_workers=4)
    print("Loaded datasets")

    # Initialize Model
    def _make_head():
        # Linear -> BatchNorm -> ReLU -> Linear head mapping an encoding to 2 outputs.
        return nn.Sequential(
            nn.Linear(args.latent_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 2)).to(device)

    # Construction order is kept (state first, then velocity) so the seeded
    # parameter initialisation is unchanged.
    state_model = _make_head()
    vel_model = _make_head()
    optimizer = optim.Adam(list(state_model.parameters()) + list(vel_model.parameters()), lr=args.lr)

    # Train Model
    train_encoding_linear(args, [state_model, vel_model], optimizer, device=device)